{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a9a37999-b514-42b8-aed2-ef019dc481b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import findspark\n",
    "from pyspark.sql import SparkSession\n",
    "import pyspark.sql.functions as F\n",
    "import pyspark.sql.types as T\n",
    "import pandas as pd\n",
    "from functools import reduce\n",
    "\n",
    "findspark.init()\n",
    "\n",
    "spark = SparkSession.builder.master(\"local[*]\").appName(\"Recipes ML model 2\").config(\"spark.driver.memory\",\"8g\").getOrCreate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9080b45b-5a42-4e34-903e-ad01a27f4033",
   "metadata": {},
   "outputs": [],
   "source": [
    "file_path = r\"D:\\bigdata\\spark_ practice\\DataAnalysisWithPythonAndPySpark-Data-trunk\\recipes\\epi_r.csv\"\n",
    "food = spark.read.csv(file_path,inferSchema=True,header=True)\n",
    "\n",
    "def sanitize_column_name(name):\n",
    "    answer = name\n",
    "    for i,j in ((\" \",\"_\"),(\"-\",\"_\"),(\"&\",\"and\"),(\"/\",\"_\")):\n",
    "        answer = answer.replace(i,j)\n",
    "    return \"\".join([char for char in answer if char.isalpha() or char.isdigit() or char==\"_\"])\n",
    "\n",
    "food = food.toDF(*[sanitize_column_name(col) for col in food.columns])\n",
    "food = food.where(F.col(\"cakeweek\").isin([0.0,1.0])|F.col(\"cakeweek\").isNull()\n",
    "                 & F.col(\"wasteless\").isin([0.0,1.0])|F.col(\"wasteless\").isNull())\n",
    "\n",
    "IDENTIFIERS = [\"title\"]\n",
    "CONTINUOUS_COLUMNS = [\"rating\",\"calories\",\"protein\",\"fat\",\"sodium\"]\n",
    "TARGET_COLUMN = [\"dessert\"]\n",
    "BINARY_COLUMNS = [x for x in food.columns if x not in IDENTIFIERS and x not in  CONTINUOUS_COLUMNS and x not in TARGET_COLUMN]\n",
    "\n",
    "food = food.dropna(how=\"all\").dropna(subset=TARGET_COLUMN)\n",
    "\n",
    "\n",
    "from typing import Optional\n",
    "@F.udf(T.BooleanType())\n",
    "def is_a_number(value:Optional[str])->bool:\n",
    "    if not value:\n",
    "        return True\n",
    "    elif value.replace('.','').isnumeric():\n",
    "        return True\n",
    "    else:\n",
    "        return False\n",
    "\n",
    "food = food.where(\n",
    "    is_a_number(F.col(\"rating\"))&is_a_number(F.col(\"calories\"))\n",
    ").withColumns(\n",
    "    {col:F.col(col).cast(\"double\") for col in [\"rating\",\"calories\"]}\n",
    ")\n",
    "\n",
    "maximum = {\n",
    "    \"calories\":3184.0,\n",
    "    \"protein\":173.0,\n",
    "    \"fat\":207.0,\n",
    "    \"sodium\":5649.0\n",
    "}\n",
    "for k,v in maximum.items():\n",
    "    food = food.withColumn(k,F.when(F.isnull(F.col(k)),F.col(k)).otherwise(F.least(F.col(k),F.lit(v))))\n",
    "\n",
    "inst_sum_of_binary_columns = list(map(lambda x:F.sum(F.col(x)).alias(x),BINARY_COLUMNS))\n",
    "sum_of_binary_columns = food.select(inst_sum_of_binary_columns)\n",
    "sum_of_binary_columns = sum_of_binary_columns.head().asDict()\n",
    "num_rows = food.count()\n",
    "too_rare_features = [k for k,v in sum_of_binary_columns.items() if v<10 or v>num_rows-10]\n",
    "BINARY_COLUMNS = list(set(BINARY_COLUMNS) - set(too_rare_features))\n",
    "\n",
    "food = (\n",
    "    food.withColumn(\"protein_ratio\",F.col(\"protein\")*4/F.col(\"calories\"))\n",
    "    .withColumn(\"fat_ratio\",F.col(\"fat\")*9/F.col(\"calories\"))\n",
    ")\n",
    "ratio_set=[\"protein_ratio\",\"fat_ratio\"]\n",
    "food = food.fillna(0.0,subset=ratio_set)\n",
    "CONTINUOUS_COLUMNS.extend(ratio_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dfcfcba7-218e-4954-b94e-2e86aa92ab8c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20048 682\n"
     ]
    }
   ],
   "source": [
    "print(food.count(),len(food.columns))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "44658449-88a6-4da3-8c4e-675af8aa3807",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml.feature import MinMaxScaler\n",
    "from pyspark.ml.feature import VectorAssembler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a8104bb4-130c-44a4-a84c-aaac6fcd66fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "continuous_assembler = VectorAssembler(inputCols=CONTINUOUS_COLUMNS,outputCol=\"continuous\")\n",
    "continuous_scaler = MinMaxScaler(inputCol=\"continuous\",outputCol=\"continuous_scaled\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d9b02b36-8fa1-40c5-80dd-7ae5c2c760fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Param(parent='VectorAssembler_f24002d4c5af', name='outputCol', doc='output column name.')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "continuous_assembler.outputCol                                                                                                                                                                                                                                                   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9365c966-a9e4-4351-9318-f2eb82598fe3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'continuous'"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "continuous_assembler.getOutputCol()     "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9a5dcf0b-b54a-444d-8849-5c20a9c2992f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'outputCol: output column name. (default: VectorAssembler_f24002d4c5af__output, current: continuous)'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "continuous_assembler.explainParam(\"outputCol\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "112dd31e-dafd-4d56-bf7d-9db1413b484e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'more_continuous'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "continuous_assembler.setOutputCol(\"more_continuous\")\n",
    "continuous_assembler.getOutputCol()     "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3d07f019-56e1-4798-90f5-5bb3a20e96fc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['rating',\n",
       " 'calories',\n",
       " 'protein',\n",
       " 'fat',\n",
       " 'sodium',\n",
       " 'protein_ratio',\n",
       " 'fat_ratio']"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "continuous_assembler.getInputCols()  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ed4190a7-e521-42dc-bd44-581e48155571",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "handleInvalid: How to handle invalid data (NULL and NaN values). Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the output). Column lengths are taken from the size of ML Attribute Group, which can be set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'). (default: error, current: skip)\n",
      "inputCols: input column names. (current: ['one', 'two', 'three'])\n",
      "outputCol: output column name. (default: VectorAssembler_f24002d4c5af__output, current: more_continuous)\n"
     ]
    }
   ],
   "source": [
    "continuous_assembler.setParams(inputCols=[\"one\",\"two\",\"three\"],handleInvalid=\"skip\")\n",
    "print(continuous_assembler.explainParams())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2b710b2c-70fe-413e-b2a3-9271fdb0c396",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "handleInvalid: How to handle invalid data (NULL and NaN values). Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the output). Column lengths are taken from the size of ML Attribute Group, which can be set using `VectorSizeHint` in a pipeline before `VectorAssembler`. Column lengths can also be inferred from first rows of the data since it is safe to do so but only in case of 'error' or 'skip'). (default: error)\n",
      "inputCols: input column names. (undefined)\n",
      "outputCol: output column name. (default: VectorAssembler_f24002d4c5af__output, current: more_continuous)\n"
     ]
    }
   ],
   "source": [
    "continuous_assembler.clear(continuous_assembler.handleInvalid)\n",
    "continuous_assembler.clear(continuous_assembler.inputCols)\n",
    "print(continuous_assembler.explainParams())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "be83e37a-3587-4ae9-baae-7ac1dffc290e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['rating', 'calories', 'protein', 'fat', 'sodium', 'protein_ratio', 'fat_ratio']\n"
     ]
    }
   ],
   "source": [
    "continuous_assembler.setInputCols(CONTINUOUS_COLUMNS)\n",
    "print(continuous_assembler.getInputCols())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ba0dd9b2-0859-40a3-aaf9-25415d9c744f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.ml import Pipeline\n",
    "import pyspark.ml.feature as MF\n",
    "\n",
    "imputer = MF.Imputer(strategy=\"mean\",inputCols=[\"calories\",\"protein\",\"fat\",\"sodium\"],\n",
    "                     outputCols=[\"calories_i\",\"protein_i\",\"fat_i\",\"sodium_i\"])\n",
    "\n",
    "continuous_assembler = MF.VectorAssembler(inputCols=[\"rating\",\"calories_i\",\"protein_i\",\"fat_i\",\"sodium_i\"],\n",
    "                                         outputCol = \"continuous\")\n",
    "\n",
    "continuous_scaler = MF.MinMaxScaler(inputCol=\"continuous\",outputCol=\"continuous_scaled\")\n",
    "\n",
    "food_pipeline = Pipeline(stages = [imputer,continuous_assembler,continuous_scaler])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ace46ea7-7891-4900-b2c6-b55df8cca048",
   "metadata": {},
   "outputs": [],
   "source": [
    "preml_assembler = MF.VectorAssembler(inputCols=BINARY_COLUMNS + [\"continuous_scaled\"] + [\"protein_ratio\",\"fat_ratio\"],\n",
    "                                    outputCol = \"features\")\n",
    "food_pipeline.setStages([imputer,continuous_assembler,continuous_scaler,preml_assembler])\n",
    "\n",
    "food_pipeline_model = food_pipeline.fit(food)\n",
    "food_features = food_pipeline_model.transform(food)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2e4c5a67-5771-4fb3-a1df-9aff27caa212",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------+-------+--------------------+\n",
      "|               title|dessert|            features|\n",
      "+--------------------+-------+--------------------+\n",
      "|Lentil, Apple, an...|    0.0|(513,[25,85,110,1...|\n",
      "|Boudin Blanc Terr...|    0.0|(513,[80,100,111,...|\n",
      "|Potato and Fennel...|    0.0|(513,[63,133,143,...|\n",
      "|Mahi-Mahi in Toma...|    0.0|(513,[100,132,140...|\n",
      "|Spinach Noodle Ca...|    0.0|(513,[5,100,113,1...|\n",
      "|      The Best Blts |    0.0|(513,[25,100,110,...|\n",
      "|Ham and Spring Ve...|    0.0|(513,[30,35,41,77...|\n",
      "|Spicy-Sweet Kumqu...|    0.0|(513,[85,111,133,...|\n",
      "|Korean Marinated ...|    0.0|(513,[20,35,59,11...|\n",
      "|Ham Persillade wi...|    0.0|(513,[41,71,89,99...|\n",
      "|Yams Braised with...|    0.0|(513,[100,134,190...|\n",
      "|  Spicy Noodle Soup |    0.0|(513,[18,24,35,48...|\n",
      "|Banana-Chocolate ...|    1.0|(513,[4,25,53,85,...|\n",
      "|Beef Tenderloin w...|    0.0|(513,[48,100,103,...|\n",
      "|      Peach Mustard |    0.0|(513,[27,100,284,...|\n",
      "|Raw Cream of Spin...|    0.0|(513,[48,111,133,...|\n",
      "|Sweet Buttermilk ...|    1.0|(513,[19,23,28,60...|\n",
      "|Crisp Braised Por...|    0.0|(513,[133,143,192...|\n",
      "|Mozzarella-Topped...|    0.0|(513,[100,132,167...|\n",
      "|Tuna, Asparagus, ...|    0.0|(513,[27,30,35,10...|\n",
      "+--------------------+-------+--------------------+\n",
      "only showing top 20 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "food_features.select(\"title\",\"dessert\",\"features\").show(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "0e9a7f90-73f7-4537-983a-162ab5d73eee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(food_features.schema[\"features\"].metadata)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "188c26c4-2cab-4fa6-8d92-5c37d038eea6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Pipeline_5be98d6b2709"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.classification import LogisticRegression\n",
    "\n",
    "lr = LogisticRegression(featuresCol=\"features\",labelCol=\"dessert\",predictionCol=\"prediction\")\n",
    "food_pipeline.setStages([imputer,continuous_assembler,continuous_scaler,preml_assembler,lr])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "669b8db4-b22c-4a23-b8b2-5074c4deeefb",
   "metadata": {},
   "outputs": [],
   "source": [
    "train,test = food.randomSplit([0.7,0.3],13)\n",
    "train.cache()\n",
    "food_pipeline_model = food_pipeline.fit(train)\n",
    "results = food_pipeline_model.transform(test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "2f7a3d2e-5025-406b-8c94-aa3c975b8be2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------+----------------------------------------+-------------------------------------------+\n",
      "|prediction|rawPrediction                           |probability                                |\n",
      "+----------+----------------------------------------+-------------------------------------------+\n",
      "|1.0       |[-13.828674499765471,13.828674499765471]|[9.8692134991871E-7,0.9999990130786501]    |\n",
      "|0.0       |[13.611382870191992,-13.611382870191992]|[0.9999987735467576,1.226453242408887E-6]  |\n",
      "|0.0       |[22.816580607506552,-22.816580607506552]|[0.9999999998767222,1.2327783238674783E-10]|\n",
      "+----------+----------------------------------------+-------------------------------------------+\n",
      "only showing top 3 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "results.select(\"prediction\",\"rawPrediction\",\"probability\").show(3,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "ece26a91-ac68-45e9-b91a-72e45b970e61",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "产品            A    B\n",
      "日期                  \n",
      "2023-01-01  100  150\n",
      "2023-01-02  120  180\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# 创建示例数据\n",
    "df = pd.DataFrame({\n",
    "    '日期': ['2023-01-01', '2023-01-01', '2023-01-02', '2023-01-02'],\n",
    "    '产品': ['A', 'B', 'A', 'B'],\n",
    "    '销量': [100, 150, 120, 180],\n",
    "    '价格': [10, 15, 12, 16]\n",
    "})\n",
    "\n",
    "# 使用pivot_table进行汇总\n",
    "result = pd.pivot_table(df, values='销量', index='日期', columns='产品', aggfunc='sum')\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "3835ec21-a31c-4036-aa58-130796d80f44",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+-------+----+---+\n",
      "|dessert| 0.0|1.0|\n",
      "+-------+----+---+\n",
      "|    0.0|4988| 86|\n",
      "|    1.0|  92|970|\n",
      "+-------+----+---+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "results.groupby(\"dessert\").pivot(\"prediction\").count().show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "f4dff4a9-925d-4f46-9d74-f7d347a5a837",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<pyspark.ml.classification.BinaryLogisticRegressionSummary at 0x217cf2fe520>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lr_model = food_pipeline_model.stages[-1]\n",
    "metrics = lr_model.evaluate(results.select(\"dessert\",\"features\"))\n",
    "metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "187b1873-9dc8-4850-9400-6e2a8bb691f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9709908735332464"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metrics.accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "1e3bc034-6050-4a7e-8198-150c84e89305",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9185606060606061"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metrics.precisionByLabel[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "b9ca72bf-3bfb-4015-ad59-3510a592abc9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9133709981167608"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metrics.recallByLabel[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "ecc1359c-8216-4bc9-b675-064a5318be8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# from pyspark.mllib.evaluation import MulticlassMetrics\n",
    "\n",
    "# predictionAndLabel = results.select(\"prediction\",\"dessert\").rdd\n",
    "# metrics_rdd = MulticlassMetrics(predictionAndLabel)\n",
    "\n",
    "# metrics_rdd "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "1ebe2e4e-3e99-417f-bfe2-6da3528f248e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# metrics_rdd.precision(1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "0ff43a1d-3e0f-460f-93de-1248e42ab0a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# metrics_rdd.recall(1.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "69a146a7-5296-4126-9d5c-9eb3c05890cb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+--------------------------------------------------------+-------+----------------------------------------+\n",
      "|title                                                   |dessert|rawPrediction                           |\n",
      "+--------------------------------------------------------+-------+----------------------------------------+\n",
      "|\"\"\"Cannoli\"\" Ice Cream Sandwiches \"                     |1.0    |[-13.828674499765471,13.828674499765471]|\n",
      "|\"\"\"Virgin Mary\"\" Aspic \"                                |0.0    |[13.611382870191992,-13.611382870191992]|\n",
      "|\"Pasta with Lobster, Tomatoes and \"\"Herbes de Maquis\"\" \"|0.0    |[22.816580607506552,-22.816580607506552]|\n",
      "|Acini di Pepe Pasta with Garlic and Olives              |0.0    |[18.41025799788323,-18.41025799788323]  |\n",
      "|Acorn Squash with Kale and Sausage                      |0.0    |[23.666443400495062,-23.666443400495062]|\n",
      "+--------------------------------------------------------+-------+----------------------------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "results.select(\"title\",\"dessert\",\"rawPrediction\").show(5,False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "48aaf24e-ed7a-4652-830b-5e84c833bc31",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9934616452399035"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "\n",
    "evaluator = BinaryClassificationEvaluator(labelCol=\"dessert\",rawPredictionCol =\"rawPrediction\", metricName = \"areaUnderROC\")\n",
    "evaluator.evaluate(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "270ae96b-ff08-4be8-92b4-f40b011382c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+---+--------------------+\n",
      "|FPR|                 TPR|\n",
      "+---+--------------------+\n",
      "|0.0|                 0.0|\n",
      "|0.0|0.004380724810832338|\n",
      "|0.0|0.008761449621664676|\n",
      "|0.0|0.013142174432497013|\n",
      "|0.0|0.017921146953405017|\n",
      "+---+--------------------+\n",
      "only showing top 5 rows\n",
      "\n"
     ]
    }
   ],
   "source": [
    "lr_model.summary.roc.show(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "9c8ce567-8f89-4c72-8a4f-da44e0a12613",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAcoAAAHACAYAAAAiByi6AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/H5lhTAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA790lEQVR4nO3deVxVdf7H8fcF2VxAzcCN1GrKzNLSZHAZU5kozXKmJkpHHTPbrJ9FNWmLZpY4jZrN5GiZtptmaZmaG8aUZVkq5aTZmJpmgpoJiMp2z++PExAjIFzuvd+7vJ6Px33EOZwDH05433zP+S4Oy7IsAQCASoWYLgAAAF9GUAIAUA2CEgCAahCUAABUg6AEAKAaBCUAANUgKAEAqAZBCQBANeqZLsDbnE6nfvzxRzVq1EgOh8N0OQAAQyzLUl5enlq2bKmQkKrbjUEXlD/++KPi4+NNlwEA8BH79u1T69atq/x80AVlo0aNJNkXJjo62nA1AABTcnNzFR8fX5YLVQm6oCy93RodHU1QAgBO+xiOzjwAAFSDoAQAoBoEJQAA1SAoAQCoBkEJAEA1CEoAAKpBUAIAUA2CEgCAahCUAABUg6AEAKAaRoPyww8/1MCBA9WyZUs5HA698847pz0nIyNDl156qSIiInTuuefqpZde8nidAIDgZTQo8/Pz1alTJ82cObNGx+/evVsDBgxQnz59lJmZqXvuuUe33HKLVq1a5eFKAQDByuik6FdddZWuuuqqGh8/e/ZstWvXTtOmTZMkXXDBBVq/fr2efvppJScne6pMAEAQ86vVQzZs2KCkpKQK+5KTk3XPPfdUeU5BQYEKCgrKtnNzcz1Vnt8oLHbqyx+OqrDYKUkqcVrKyjmpwhKnDuaeVEGxU0fyC5V7sqjsHKclZeeeLDsHAEx75sZLdH7z6pfIcge/CsqsrCzFxcVV2BcXF6fc3FydOHFCUVFRp5yTlpamiRMneqtEtylxWvps908qKHaqoMipvUfy9fWPuco7WXzKsccLi3Uwr6CSr1K5w3kFyq3k6wCAPzlZVOKV7+NXQemKcePGKTU1tWy7dKFOX2FZlgqKnTqQc1KHjxVo+4FcfXfwmBZv2V9pKLpLs4bhatYwomw7OipMjaPCVD88VGc2ilBYaIhaxEQqJKR8nbaGEfV0RoOIyr4cAHiW5ZSysqUWLcp2tTuzgVe+tV8FZfPmzZWdnV1hX3Z2tqKjoyttTUpSRESEIiJ85809c99Rrfo6Szuy8vTDz8d14OhJ5RVUHYjtmjVQw4h6im0UodjoSHVo0UgRYaEVjglxONQyJlL1QmvWNyuiXog6topRaEj1i5UCgE9wOqXbbpMWLZLWrpW6dvXqt/eroExMTNSKFSsq7FuzZo0SExMNVVRz6/97WNPW7NCWvUcr/XxYqEPRkWE6s1GELmwZo87xMbr8/FjFN63v3UIBwJc4ndKtt0pz50ohIdLOncEVlMeOHdPOnTvLtnfv3q3MzEw1bdpUZ511lsaNG6f9+/frlVdekSTdfvvtevbZZ/XXv/5VN998s9atW6c333xTy5cvN/UjVCs796SeXbdTX/1wVF/+kFO2v1PrGF19cUuddUZ9tW4SpVaNo9Qgop7CatgiBICg4HRKo0ZJ8+bZIfnqq9KNN3q9DKNB+cUXX6hPnz5l26XPEocPH66XXnpJBw4c0N69e8s+365dOy1fvlz33nuvnnnmGbVu3VovvPCCTw4NWf/fw7pvUaayc8s72Qzq3FI392ynC1ty2xMAqlVSIt1yi/TSS3ZIvv66kZCUJIdlWZaR72xIbm6uYmJilJOTo+joaI98j8/3HNENz22QZUmtm0Tptt+dzW1UAKipkhJp5Ejp5Zel0FA7JFNS3P5tapoHfvWM0h9YlqXpq7+VZUk9z22mp1M668xGvtOZCAB8XlGRtH+/HZLz50s33GC0HILSzZZ9dUAbdv2kEIc0aVBHQhIAaisyUnr3XenTT6W+fU1Xw+oh7nQkv1DjFm+VJP2pS7zaNfPOGB8A8HvFxdLChVLp08D69X0iJCWC0q1WbD2gYwXFanNGfT1y9QWmywEA/1BcLA0bZnfWGT/edDWn4NarmxSXODV3/W5J0pCEs9QoMsxwRQDgB4qLpT//2W5N1qsndeliuqJTEJRusnpbtnYfzlfTBuEanNDGdDkA4PuKi6UhQ6Q335TCwuyZd6691nRVpyAo3WT+Z/Z4z4EXt1DDCC4rAFSrqMgOyUWL7JB8+21p4EDTVVWKZ5RusOvQMa3feViSNOiSVoarAQAfZ1nlIRkeLi1e7LMhKRGUbvFO5o+SpN7nnalLzmpiuBoA8HEOh5ScbA8DWbxYuvpq0xVVi3uEbpC576gk6YoL46o/EABgGzlSuvJKqZXv34WjRekGP/x8XJLU7gzGTQJApQoLpfvukw4eLN/nByEpEZR1ZlmW9v98QpLUuglzuQLAKQoLpT/9SZo+XRowwF4VxI8QlHX03aFjKih2Kjw0RM1jIk2XAwC+paBAuv56aelS+5nkk0/aq4H4EZ5R1tEH3xySJCWc3VTh9fzrfz4AeFRBgXTdddLy5XZILl0q/f73pquqNYKyjj7b/ZMku8crAOAXJ0/aIblihRQVJb33ntSvn+mqXEJQ1tGhPHth5rZ05AGAcmPGlIfksmU+M8G5K7hXWEdHjhdKkpo0CDdcCQD4kHHjpA4d7NuufhySEi3KOrEsq6xF2awhQQkgyFmWPZmAJLVtK331lb34sp+jRVkHuSeLdbLI7uYcF02PVwBB7MQJexq6JUvK9wVASEoEZZ0czD0pSYqJClNkWGD8QgBArR0/Ll1zjX2b9eabpZwc0xW5Fbde6yA7177tGtsownAlAGBIaUimp0sNGthDQGJiTFflVgRlHWT/0qLktiuAoJSfb99u/eADqWFD6f33pZ49TVfldgRlHWTn2UEZG02LEkCQyc+3V/3IyLBDcuVKqUcP01V5BM8o6+DgL7deaVECCDrPP2+HZKNG0qpVARuSEi3KOtl/1J4MPY5nlACCzZgx0vffSykpUmKi6Wo8iqCsg+8OHpMknRfXyHAlAOAF+flSeLgUFmZPbD5jhumKvIJbr3VQUGyPoWwQwd8bAAJcXp690PKQIVJRkelqvIp3+DooKrGDMiyUvzcABLC8POmqq6SPP7aHfuzaJZ1/vumqvIZ3+DooD0qH4UoAwENyc+2W5McfS40bS2vXBlVISrQo66S4xJJEixJAgMrJsUPy00+lJk2kNWukLl1MV+V1BGUdFP7SoqxHixJAoMnJkZKTpc8+s0Ny7Vrp0ktNV2UETaE6KHbaLcpwWpQAAs3XX0tffik1bWpPTxekISnRonSZ02mp5JegrEdQAgg03bvbCy6fcYbUubPpaowiKF1U5HSWfcytVwAB4eefpYMHyzvr9Otnth4fQVPIRaUdeSRuvQIIAEeOSElJUu/e0vbtpqvxKbzDu6h0aIgk1QuhRQnAj5WG5ObNktMplZSYrsinEJQuKvqlRelwSKEEJQB/9dNP9i3WLVuk2Fh7yayOHU1X5VN4RumisskGQkLkcBCUAPzQTz/ZLcnMzPKQ7NDBdFU+hxali8onGyAkAfih0pZkZqYUF0dIVoMWpYvKJxvgbw0AfqhePSkiQmre3A7J9u1NV+SzCEoXFTuZ5xWAH4uJsRdcPnRI+s1vTFfj02gOuYh5XgH4nYMHpRdfLN9u3JiQrAFalC5inlcAfiU7W+rbV9q2TSookG6/3XRFfoPmkItoUQLwG1lZUp8+dki2bMmMO7XEu7yLfj08BAB81oEDdkhu3y61aiVlZHC7tZa49eqisqCsx61XAD6qNCR37JBat7Z7t557rumq/A7NIReVzsxTjxYlAF90/Hh5SMbH2y1JQtIlvMu7qLiE4SEAfFj9+tLIkdJZZ9khec45pivyWwSliwrLgpJLCMBHPfCA9NVX0tlnm67Er/Eu76LSXq/MzAPAZ+zfLw0ZIuXmlu+LiTFXT4CgM4+LSmfmCefWKwBfsG+f/Uzyu+/sZbIWLDBdUcCgOeSiQjrzAPAVe/dKl19uh2S7dtLf/ma6ooDCu7yLyjrz1OMSAjDo++/tkNy1y34WmZEhtWljuqqAwru8i8onHODWKwBD9uyxQ3L3brtXa0aG3csVbsUzShcVMYUdAJMsS7rpJjssS0OydWvTVQUk3uVdVMSk6ABMcjikuXOlnj2lf/+bkPQgWpQuYlJ0AEYUF9uLLktShw7Shx/aoQmP4V3eRUXMzAPA23btkjp2lNatK99HSHocQemiIiYcAOBN331nd9zZscOeceeXsdzwPOPv8jNnzlTbtm0VGRmphIQEbdy4sdrjZ8yYofPPP19RUVGKj4/Xvffeq5MnT3qp2nKlEw5w6xWAx+3caYfkvn1S+/bS8uUSY7i9xuiVXrhwoVJTUzVhwgRt3rxZnTp1UnJysg4ePFjp8fPnz9fYsWM1YcIEbd++XXPnztXChQv10EMPeblyhocA8JL//tcOyR9+kC64wF4qq3lz01UFFaNBOX36dI0aNUojRoxQhw4dNHv2bNWvX1/z5s2r9PhPPvlEPXr00ODBg9W2bVtdccUVuummm07bCvWEsuEhTDgAwFNKQ3L/frvjDiFphLF3+cLCQm3atElJSUnlxYSEKCkpSRs2bKj0nO7du2vTpk1lwbhr1y6tWLFC/fv3r/L7FBQUKDc3t8LLHcqGh9CiBOApM2ZIP/4oXXihHZJxcaYrCkrGhoccPnxYJSUlivuf//FxcXH65ptvKj1n8ODBOnz4sHr27CnLslRcXKzbb7+92luvaWlpmjhxoltrlxgeAsALZsyQGjaU7rtPio01XU3Q8qt3+YyMDE2ePFn/+te/tHnzZi1evFjLly/XpEmTqjxn3LhxysnJKXvt27fPLbUUMuEAAE/Yv7+8R2tYmD3BOSFplLEWZbNmzRQaGqrs7OwK+7Ozs9W8invwjz76qIYOHapbbrlFknTRRRcpPz9ft956qx5++GGFVNILLCIiQhEREW6vv5iFmwG42/bt9lJZ114rzZpFz1YfYez/Qnh4uLp06aL09PSyfU6nU+np6UpMTKz0nOPHj58ShqGhoZIky7I8V2wlyud6pUUJwA22bbM77mRnS599Jh07Zroi/MLoFHapqakaPny4unbtqm7dumnGjBnKz8/XiBEjJEnDhg1Tq1atlJaWJkkaOHCgpk+frksuuUQJCQnauXOnHn30UQ0cOLAsML2liBYlAHf5+mupb1/p4EGpc2dp7VopOtp0VfiF0aBMSUnRoUOHNH78eGVlZalz585auXJlWQefvXv3VmhBPvLII3I4HHrkkUe0f/9+nXnmmRo4cKCefPJJr9de3uuVoARQB//5jx2Shw5Jl1xih2TTpqarwq84LG/fszQsNzdXMTExysnJUXQd/mK79tn1+vKHHL0wrKuSOtBlG4ALtm61Q/LwYenSS6U1awhJL6ppHtAccpHzlz8vQhlHCcBVu3ZJR49KXbrQkvRhLLPlIucvDXEm7gfgsmuvtedtvewyqUkT09WgCgSli0p+aVLSogRQK19+KTVuLLVpY29fcYXRcnB63Hp1UemT3RCalABqavNme5xknz72SiDwCwSli7j1CqBWNm2SkpKkn3+252xl+IffIChdVBqUoSQlgNP54ovykExMlFatkmJiTFeFGiIoXVTa6zWEZ5QAqvP559Lvf2/3bu3eXVq5ktakn6Ezj4tKW5TkJIAqbd5sh2ROjtSjh/T++1KjRqarQi0RlC4qD0qSEkAVWreWWrWSLrpIWrGCkPRTBKWLSlfBcRCUAKoSG2svuFy/vr2uJPwSzyhdRGceAJXasEF6+eXy7dhYQtLP0aJ0EcNDAJzik0+kK6+0l8hq1kwaMMB0RXADWpQuYq5XABV8/LGUnCzl5dnrSl5+uemK4CYEpYssOvMAKLV+fXlLsm9fadkyqUED01XBTQhKF5WNoyQngeD20UflIdmvn/Tee3bnHQQMgtJF5c8oSUogaO3aJV11lZSfb4+XJCQDEp15XOR00pkHCHrt2kmjR0uZmdI770hRUaYrggcQlC4qXT2E4SFAEHM4pClTpKIiKTzcdDXwEG69uoiZeYAgtW6dNGiQdOKEve1wEJIBjqB0UQnjKIHgk54uXX219O670lNPma4GXkJQuojVQ4Ags3atHZInTtgTCYwda7oieAlB6SKL1UOA4LFmjTRwoHTypB2Wb78tRUSYrgpeQlC6qHwcJUkJBLRVq8pDcuBA6a23CMkgQ1C6iLlegSBw/Lg0fLhUUCBdey0hGaQIShdYllU2PIQWJRDA6teXli6Vhg2T3nyT3q1BinGULigNSYmgBAJSfn75XK3dutkvBC1alC4o+VVSMuEAEGCWLZPOPlv6/HPTlcBHEJQucP66SUlOAoHjvfekP/5ROnhQmj3bdDXwEQSlCyreejVXBwA3WrpUuu46ezq6P/2JoEQZgtIFPKMEAsy770rXX2+HZEqKNH++FBZmuir4CILSBb++9UpQAn5uyZLykLzxRum116R69HNEOYLSBb8OSnIS8GOWJb3wglRcLN10k/Tqq4QkTkFQusA6/SEA/IHDYU8i8NRT0iuvEJKoFEHpAp5RAn5u69byf8hRUdIDDxCSqBJBCSC4LFokXXKJ9NBDFf/qBapAUAIIHgsX2s8iS0qkAwcIStQIQQkgOLzxhjR4sB2SI0ZIc+dKIbwF4vT4LQEQ+ObPl/78Z8nplG6+2e7pGhpquir4CYISQGB77TVp6FA7JEeOlObMoSWJWuG3BUBgKyy0Q/KWW6TnnyckUWv0hwYQ2G6+WfrNb6QePQhJuITfGgCB5+237RVASvXqRUjCZfzmuIIe5YDvmjfPXv2jb18pJ8d0NQgABGUdMTEP4EPmzrWfRVqWdPnlUnS06YoQAAhKAIFhzpzykLz7bumf/+QvWbgFQQnA/z3/vHTrrfbHY8ZIzzxDSMJtCEoA/u2116TbbrM/vuce6emnCUm4FcNDAPi3nj2lNm2kP/5RmjaNkITbEZQA/FvbttKmTVLTpoQkPIJbrwD8z8yZ0uLF5dtnnEFIwmNoUQLwL//4h91hp149acsWqWNH0xUhwNGiBOA/nnnGDklJuv9+6cILzdaDoEBQusBiah7A+55+2u7VKkkPPSRNnsztVngFQVlH/DMFvGD6dCk11f744YelJ54gJOE1BCUA37Z6tXTfffbHjz4qTZpESMKr6MwDwLclJdlT07VqJT32mOlqEIQISgC+yem0l8YKCbGnqKMVCUO49QrA96SlSSkpUlGRvU1IwiCCEoBvmTzZ7tX61lvS0qWmqwHMB+XMmTPVtm1bRUZGKiEhQRs3bqz2+KNHj2r06NFq0aKFIiIidN5552nFihVeqhaARz3xhN2rVbI77Vx3ndl6ABl+Rrlw4UKlpqZq9uzZSkhI0IwZM5ScnKwdO3YoNjb2lOMLCwv1+9//XrGxsXrrrbfUqlUrff/992rcuLH3iwfgXo8/Lk2YYH/85JN2qxLwAUaDcvr06Ro1apRGjBghSZo9e7aWL1+uefPmaezYsaccP2/ePB05ckSffPKJwsLCJElt27b1ZskAPGHixPIerWlpUiX//gFTjN16LSws1KZNm5SUlFReTEiIkpKStGHDhkrPWbp0qRITEzV69GjFxcWpY8eOmjx5skpKSqr8PgUFBcrNza3wqiuLiXkA99m9W5oyxf54yhRCEj7HWFAePnxYJSUliouLq7A/Li5OWVlZlZ6za9cuvfXWWyopKdGKFSv06KOPatq0aXriiSeq/D5paWmKiYkpe8XHx7v153DQGw+om3btpGXL7LUkH3zQdDXAKYx35qkNp9Op2NhYPf/88+rSpYtSUlL08MMPa/bs2VWeM27cOOXk5JS99u3b58WKAVTKsqSDB8u3+/Urn6IO8DHGgrJZs2YKDQ1VdnZ2hf3Z2dlq3rx5pee0aNFC5513nkJDQ8v2XXDBBcrKylJhYWGl50RERCg6OrrCC4BBlmVPRXfxxdL27aarAU7LWFCGh4erS5cuSk9PL9vndDqVnp6uxMTESs/p0aOHdu7cKafTWbbv22+/VYsWLRQeHu7xmgHUkWXZwz+efFLKzpY+/NB0RcBpGb31mpqaqjlz5ujll1/W9u3bdccddyg/P7+sF+ywYcM0bty4suPvuOMOHTlyRGPGjNG3336r5cuXa/LkyRo9erSpHwFATVmWPeQjLc3efuYZ6bbbzNYE1IDR4SEpKSk6dOiQxo8fr6ysLHXu3FkrV64s6+Czd+9ehYSUZ3l8fLxWrVqle++9VxdffLFatWqlMWPG6EE6AAC+zbLs3qxPPWVv/+Mf0t13m60JqCGHZQXXYIfc3FzFxMQoJyfH5eeVP+cX6pJJayRJ303ur9AQer4CVbIs6a9/laZOtbeffVbiLhB8QE3zgNVDAHjWyZPSRx/ZH8+cKd15p9l6gFoiKAF4VlSUtGqVvQDzn/5kuhqg1vxqHCUAP2FZ0rp15dsxMYQk/BZB6YKgeqgL1JZlSWPG2JMITJtmuhqgzrj1Wkd04wF+xbKk//s/u8OOwyGxsg8CAEEJwD0sS7rrLulf/7JD8oUXpJtvNl0VUGcEJYC6czrtkJw1yw7JefOkv/zFdFWAWxCUAOrGsuxxkbNn2yH54ovS8OGmqwLchqAEUDcOh3TOOVJIiPTSS9LQoaYrAtyKoARQd/ffL/XvL3XoYLoSwO0YHgKg9pxOacoUKSenfB8hiQBFUAKoHadTuuUWadw4acAAexsIYAQlgJorKZFGjrQ77ISE2D1dQ3gbQWDjGaULgmzBFcBWUmKPi3zlFSk0VHr9dSklxXRVgMcRlHXkYGoeBIOSEmnECOnVV+2QfOMN5m5F0CAoAZzeffeVh+SCBdL115uuCPAaHi4AOL1bb5Vat5YWLiQkEXTcFpSLFy/WxRdf7K4vB8CXdOggffutdN11pisBvK5WQfncc8/p+uuv1+DBg/XZZ59JktatW6dLLrlEQ4cOVY8ePTxSJAAvKy62O+6kp5fvi4oyVw9gUI2DcsqUKbr77ru1Z88eLV26VH379tXkyZM1ZMgQpaSk6IcfftCsWbM8WSsAbyguloYMsYeAXHeddPSo6YoAo2rcmefFF1/UnDlzNHz4cH300Ufq3bu3PvnkE+3cuVMNGjTwZI0AvKWoyA7JRYuksDB7KAhrSiLI1Tgo9+7dq759+0qSevXqpbCwME2cOJGQBAJFUZF0003S229L4eH2f6++2nRVgHE1DsqCggJFRkaWbYeHh6tp06YeKQqAlxUWSjfeKC1ZYofk4sX29HQAajeO8tFHH1X9+vUlSYWFhXriiScUExNT4Zjp06e7rzofxbw8CDgzZ5aH5JIl9kogACTVIih/97vfaceOHWXb3bt3165duyoc4wjCaWqC8WdGALrrLmnTJvv55FVXma4G8Ck1DsqMjAwPlgHA6woLpXr17EnNw8Kk114zXRHgk2p16zU3N1efffaZCgsL1a1bN5155pmeqguAJxUU2HO1Nm8uzZ7NCiBANWoclJmZmerfv7+ysrIkSY0aNdKbb76p5ORkjxUHwAMKCuxp6JYtkyIjpTFjpAsvNF0V4LNq/Gfkgw8+qHbt2unjjz/Wpk2b1K9fP911112erA2Au508Kf3xj+Uh+d57hCRwGjVuUW7atEmrV6/WpZdeKkmaN2+emjZtqtzcXEVHR3usQABuUhqS779vT0f33ntSv36mqwJ8Xo1blEeOHFHr1q3Lths3bqwGDRrop59+8khhANzo5EnpD38oD8nlywlJoIZq1Zln27ZtZc8oJcmyLG3fvl15eXll+1hBBPBBGzdKa9dK9evbIXn55aYrAvxGrYKyX79+sqyKw+2vvvpqORwOWZYlh8OhkpIStxYIwA1+9zt7LckzzpB69zZdDeBXahyUu3fv9mQdfsViah74g+PHpSNH7AWXJfv5JIBaq3FQvvzyy7r//vvLprAD4MOOH5euuUb67jspI0Nq08Z0RYDfqnFnnokTJ+rYsWOerAWAOxw/Lg0caC+6fPiw9OOPpisC/FqNg/J/n00C8EH5+fbSWOvWSY0aSatWSYmJpqsC/FqtOvMwATjgw/Lz7aWx/v1vQhJwo1oF5XnnnXfasDxy5EidCgLggmPH7JD88EMpOtoOyd/+1nRVQECoVVBOnDjxlPUnAfiAEyekn36yQ3L1aikhwXRFQMCoVVDeeOONio2N9VQtAFx15pn2c8l9+6QuXUxXAwSUGnfm4fkk4GNyc6V33y3fjo0lJAEPoNcr4I9yc6Urr7Tnb33lFdPVAAGtxrdenU6nJ+vwK5b4owEG5eTYIfnpp1KTJiyTBXhYrZ5RoiLuRsPrcnKk5GTps8/skFy7Vvpl6TsAnlHjW68ADDt6VLriCjskmza1Z94hJAGPo0UJ+IMTJ+yQ/Pzz8pDs3Nl0VUBQoEUJ+IPISKlvX3uZrHXrCEnAiwhKwB84HFJamvTVV1KnTqarAYIKQQn4qiNHpHvusW+7SnZYtmxptCQgGPGMEvBFP/0kJSVJmZn2UlmvvWa6IiBo0aIEfM3hw1K/fnZIxsZK48aZrggIarQoAV9SGpJffSXFxdkddzp0MF0VENQISlcwMQ884dAhOyS3brVD8oMPpAsuMF0VEPS49VoHTMwDt7Es6brr7JBs3lzKyCAkAR9BUAK+wOGQpk2zwzEjQ2rf3nRFAH7BrVfAJMsqnzT4ssvsFmVoqNmaAFRAixIwJTtb6tnTnpauFCEJ+ByCEjAhK0vq00f65BPp5psllrEDfJZPBOXMmTPVtm1bRUZGKiEhQRs3bqzReQsWLJDD4dCgQYM8WyDgTgcO2CG5fbvUurW0ZIkU4hP/FAFUwvi/zoULFyo1NVUTJkzQ5s2b1alTJyUnJ+vgwYPVnrdnzx7df//96tWrl5cqBdygNCS/+UaKj7c77px7rumqAFTDeFBOnz5do0aN0ogRI9ShQwfNnj1b9evX17x586o8p6SkREOGDNHEiRN19tlne7FaoA5+/FG6/HJpxw7prLPskDznHNNVATgNo0FZWFioTZs2KSkpqWxfSEiIkpKStGHDhirPe/zxxxUbG6uRI0ee9nsUFBQoNze3wgsw4vHHpW+/ldq0sUOSP/IAv2B0eMjhw4dVUlKiuLi4Cvvj4uL0zTffVHrO+vXrNXfuXGVmZtboe6SlpWnixIl1LbUCJuaBS55+WioslMaPl9q2NV0NgBoyfuu1NvLy8jR06FDNmTNHzZo1q9E548aNU05OTtlr3759bqvH4WBuHpzG0aP2WElJioqS5s0jJAE/Y7RF2axZM4WGhio7O7vC/uzsbDVv3vyU47/77jvt2bNHAwcOLNvn/KVbfb169bRjxw6d8z/PfCIiIhQREeGB6oHT2LvX7rjzpz/Ziy7zhxXgl4y2KMPDw9WlSxelp6eX7XM6nUpPT1diYuIpx7dv315bt25VZmZm2euaa65Rnz59lJmZqfj4eG+WD1Rt7167486uXdKbb9otSwB+yfgUdqmpqRo+fLi6du2qbt26acaMGcrPz9eIESMkScOGDVOrVq2UlpamyMhIdezYscL5jRs3lqRT9gPGfP+93ZLcvdvusJORITVpYroqAC4yHpQpKSk6dOiQxo8fr6ysLHXu3FkrV64s6+Czd+9ehTAYG/5izx47JPfssYd+ZGTYkwoA8FsOy7KCqhNnbm6uYmJilJOTo+joaJe+RnbuSSVMTldoiEPfTe7v5grht/bssW+3fv+9PYlARobUqpXhogBUpaZ5QFMNcJdPP7WfTf7mN4QkEECM33oFAsaNN9o9W3v1klq2NF0NADchKIG62LVLatBAKp00IyXFbD0A3I5bry4Irqe6qNLOnVLv3lK/ftJpJvEH4L8Iyjpg+HgQ++9/7Y47P/xgryXJepJAwCIogdr69ls7JPfvlzp0kD74QKpkJikAgYGgBGpjxw47JH/8UbrwQmnduvLnkwACEkEJ1NSOHfZkAgcOSB07EpJAkCAogZqKjJQiIqSLLrJDMjbWdEUAvIDhIUBNlS64XL++dOaZpqsB4CW0KIHqbNsmvftu+XabNoQkEGQISqAqX39td9y5/npp9WrT1QAwhKAEKvOf/9gddw4dsp9Jdu1quiIAhhCULrDE1DwBbevW8pC89FJp7VqpaVPTVQEwhKCsAwdT8wSer76yQ/LwYalLF0ISAEEJlPn+e6lvX+mnn+xbrWvXSk2amK4KgGEMDwFKxcdLf/iD9OWXduedxo1NVwTABxCUQKmQEOm556Tjx6WGDU1XA8BHcOsVwW3zZmnUKKmoyN4OCSEkAVRAixLBa9MmKSlJOnpUatVKeuwx0xUB8EG0KBGcvviiPCS7d5dSU01XBMBHEZQIPp9/Xh6SPXpIK1dK0dGmqwLgowhKBJfPPrNDMidH6tlTev99qVEj01UB8GEEpQssJubxTydOSIMGSbm5Uq9ehCSAGiEo68AhpubxK1FR0muvSVdeKa1YQe9WADVCr1cEvuJiqd4vv+r9+tmz7zD/IIAaokWJwPbxx9IFF9jrSpYiJAHUAkGJwLV+vX2bdedOadIk09UA8FMEJQLTRx/ZIXnsmH27de5c0xUB8FMEJQLPhx9KV10l5efbQ0Hee0+qX990VQD8FEGJwPLvf5eH5O9/Ly1davd2BQAXEZQIHJZlP4s8flxKTpbefZeQBFBnBCUCh8Mhvf229MAD0jvvEJIA3IKgdAET8/iYffvKP46JkZ56SoqMNFcPgIBCUNYFw/HMW7tWOv98aepU05UACFAEJfzXmjXSwIH2HK7//rdUUmK6IgABiKCEf1q1yg7Jkyft/771lhQaaroqAAGIoIT/WblSuvZaqaBAuuYaOyQjIkxXBSBAEZTwL++/by+VVVBgh+WiRVJ4uOmqAAQwghL+ZccOOyT/8AfpzTcJSQAexzJb8C/33CO1bSsNGCCFhZmuBkAQoEUJ3/fBB1JOTvn2oEGEJACvISjh25YutaejS06W8vJMVwMgCBGULrAs5ubxinffla6/Xioqsm+3MiUdAAMIyjpgYh4PWrKkPCRvvFF67TWpHo/UAXgfQQnfs3ixdMMNUnGxNHiw9OqrhCQAYwhK+JZ335VSUuyQHDJEeuUVQhKAUbwDwbecd57UtKndeefFF5mWDoBxBCV8ywUXSJ9/LrVqRUgC8AnceoV5CxdK6enl22edRUgC8Bm0KGHWG29If/6zPan5xo1Sx46mKwKACmhRwpz58+2QdDrt3q0dOpiuCABOQVDCjNdek4YOtUPylluk55+XQvh1BOB7eGdyARPz1NGrr0rDhtkhOWqU9NxzhCQAn8W7Ux04mJqn9j74QBo+3P5r47bbpNmzCUkAPo3OPPCunj2l666TzjxTevZZQhKAzyMo4V1hYXZP15AQQhKAX+CdCp43d659m9XptLfr1SMkAfgNWpTwrDlzpFtvtT/u29eexxUA/IhP/Fk/c+ZMtW3bVpGRkUpISNDGjRurPHbOnDnq1auXmjRpoiZNmigpKana42HQ88+Xh+T//Z+9IggA+BnjQblw4UKlpqZqwoQJ2rx5szp16qTk5GQdPHiw0uMzMjJ000036YMPPtCGDRsUHx+vK664Qvv37/dy5ajWc8/Zt1slacwYacYMugkD8EsOyzI7KjAhIUGXXXaZnn32WUmS0+lUfHy87r77bo0dO/a055eUlKhJkyZ69tlnNWzYsNMen5ubq5iYGOXk5Cg6OtqlmvcdOa5eT32gyLAQfTPpKpe+RkCbNUu6807743vvlaZNIyQB+Jya5oHRFmVhYaE2bdqkpKSksn0hISFKSkrShg0bavQ1jh8/rqKiIjVt2rTSzxcUFCg3N7fCCx60Z4/dgpSk1FRCEoDfMxqUhw8fVklJieLi4irsj4uLU1ZWVo2+xoMPPqiWLVtWCNtfS0tLU0xMTNkrPj6+znWjGm3bSgsWSA8+KE2dSkgC8HvGn1HWxZQpU7RgwQItWbJEkZGRlR4zbtw45eTklL327dvntu/vECFQJj+//OM//lGaMoWQBBAQjAZls2bNFBoaquzs7Ar7s7Oz1bx582rPnTp1qqZMmaLVq1fr4osvrvK4iIgIRUdHV3jBzf7xD3t5rO+/N10JALid0aAMDw9Xly5dlP6rRXudTqfS09OVmJhY5XlPPfWUJk2apJUrV6pr167eKBVVmTHDfia5Z4+0aJHpagDA7YxPOJCamqrhw4era9eu6tatm2bMmKH8/HyNGDFCkjRs2DC1atVKaWlpkqS//e1vGj9+vObPn6+2bduWPcts2LChGjZsaOznCEpPP2132JGkhx+W7rvPbD0A4AHGgzIlJUWHDh3S+PHjlZWVpc6dO2vlypVlHXz27t2rkF9NdzZr1iwVFhbq+uuvr/B1JkyYoMcee8ybpQe3adOk+++3P37kEenxx3kmCSAgGR9H6W3uHEcZFRaq7ZOudHOFfmDqVOmBB+yPx4+XHnuMkATgd/xiHCX80IkT0ksv2R9PmCBNnEhIAghoxm+9ws9ERUnr1klvvy3dcYfpagDA42hRoma+/LL849hYQhJA0CAoXRBcT3UlPfmk1Lmzva4kAAQZgrIOguLR3KRJdq9WSapiRRcACGQEJao2caLdq1WS0tKkcePM1gMABtCZB5V77DE7KCXpb3+T/vpXo+UAgCkEJSqyLDskH3/c3n7qqfIxkwAQhAhKnKqoyP7v1KlMSwcg6BGUqMjhsHu59u8v9expuhoAMI7OPLBvt86bZ8+6I9lhSUgCgCSCEpZl92YdOVIaNEgqKTFdEQD4FG69BjPLkh58UPr73+3tgQOl0FCzNQGAjyEoXWApAKbmsSx7yMfUqfb2s89Ko0ebrQkAfBBBWQd+OzGPZdlrSU6fbm/PnCndeafZmgDARxGUwWj8+PKQnDVLuv12s/UAgA+jM08wuuYaqXFj6bnnCEkAOA1alMHossuknTulM84wXQkA+DxalMHAsqSxY6WNG8v3EZIAUCMEZaCzLOmuu+yJza+8Uvr5Z9MVAYBfISgDmdNpD/n417/s2XamTZOaNDFdFQD4FZ5RBiqn0x7y8dxzdki++KI0fLjpqgDA7xCUgcjptHuzzpljh+RLL0nDhpmuCgD8EkEZiGbOtEMyJER6+WXpz382XREA+C2C0gWWr89gd8st0ooVdkAOGWK6GgDwawRlHTgcPjSJndNp32Z1OKSoKDsofak+APBT9HoNBCUl9jJZ48aVN3cJSQBwC1qU/q40JF9+2V4i66abpE6dTFcFAAGDoPRnJSXSiBHSq6/aITl/PiEJAG5GUPqrkhLpL3+RXnvNDskFC6TrrzddFQAEHILSHxUX25MHzJ8v1atnh+R115muCgACEkHpj9avl954ww7JhQulP/7RdEUAELAISn90+eXSCy/Y87b+4Q+mqwGAgEZQ+oviYiknp3x5rJtvNlsPAAQJxlG6wOsT8xQVSYMHS7/7nZSd7e3vDgBBjRZlHXhlSH9RkT028u23pfBwaetWKS7OG98ZACCC0rcVFUk33igtXmyH5OLFUlKS6aoAIKgQlL6qsNAOySVL7JBcskTq3990VQAQdAhKX1RYKN1wg/Tuu1JEhPTOO9KVV5quCgCCEkHpi376SfrqKzsk331XSk42XREABC2C0he1aCF98IG0c6fUr5/pagAgqDE8xFcUFEgffli+3aYNIQkAPoCg9AUnT9rT0PXrZ3faAQD4DILStNKQXLFCCguTYmJMVwQA+BWeUbrAstw0N8/Jk/ZcrStXSlFR0rJlUt++7vnaAAC3ICjroi5T85w4IQ0aJK1eLdWvLy1fbk92DgDwKQSlCQUF0rXXSmvW2CG5YoXUu7fpqgAAleAZpQlhYVK7dlKDBtL77xOSAODDCEoTQkKkWbOkL76wVwQBAPgsgtJb8vOlJ5+0JzqX7LBs395sTQCA0+IZpTfk50tXXy1lZEjffSfNm2e6IgBADRGUnnbsmDRggD3rTqNG0qhRpisCANQCQelJx47ZS2N99JEUHS2tWiX99remqwIA1AJB6Sl5eXZIrl9vh+Tq1VJCgumqAAC1RGceF5x2Xh7Lkq6/3g7JmBh7vCQhCQB+iaCsgyon5nE4pAcftJfLWrNG6tbNm2UBANyIW6+e0rev3cM1Ksp0JQCAOqBF6S45Ofbcrdu2le8jJAHA79GidIecHCk5WfrsM+nbb6WtW6XQUNNVAQDcwCdalDNnzlTbtm0VGRmphIQEbdy4sdrjFy1apPbt2ysyMlIXXXSRVqxY4aVKK3H0qHTFFXZINm0qvf46IQkAAcR4UC5cuFCpqamaMGGCNm/erE6dOik5OVkHDx6s9PhPPvlEN910k0aOHKktW7Zo0KBBGjRokP7zn/94uXLZvVuvuELauFE64wxp3Trpkku8XwcAwGMclttWIXZNQkKCLrvsMj377LOSJKfTqfj4eN19990aO3bsKcenpKQoPz9fy5YtK9v329/+Vp07d9bs2bNP+/1yc3MVExOjnJwcRUdHu1Tzd4eOqd+0fyu66IS+mv4nOyTT06VOnVz6egAA76tpHhhtURYWFmrTpk1KSkoq2xcSEqKkpCRt2LCh0nM2bNhQ4XhJSk5OrvL4goIC5ebmVni5TUmJ1KyZ3ZIkJAEgIBkNysOHD6ukpERxcXEV9sfFxSkrK6vSc7Kysmp1fFpammJiYspe8fHx7ilesteVXLdOuvhi931NAIBPCfher+PGjVNqamrZdm5ubp3DskVMpF6/JUEhDod0zhl1LREA4MOMBmWzZs0UGhqq7OzsCvuzs7PVvHnzSs9p3rx5rY6PiIhQRESEewr+Rf3weupxbjO3fk0AgG8yeus1PDxcXbp0UXp6etk+p9Op9PR0JSYmVnpOYmJiheMlac2aNVUeDwBAXRi/9Zqamqrhw4era9eu6tatm2bMmKH8/HyNGDFCkjRs2DC1atVKaWlpkqQxY8aod+/emjZtmgYMGKAFCxboiy++0PPPP2/yxwAABCjjQZmSkqJDhw5p/PjxysrKUufOnbVy5cqyDjt79+5VSEh5w7d79+6aP3++HnnkET300EP6zW9+o3feeUcdO3Y09SMAAAKY8XGU3uaOcZQAAP/nF+MoAQDwdQQlAADVICgBAKgGQQkAQDUISgAAqkFQAgBQDYISAIBqEJQAAFSDoAQAoBoEJQAA1TA+16u3lc7Yl5uba7gSAIBJpTlwuplcgy4o8/LyJKnOizcDAAJDXl6eYmJiqvx80E2K7nQ69eOPP6pRo0ZyOBwuf53c3FzFx8dr3759TK7+K1yXqnFtKsd1qRrXpnLuui6WZSkvL08tW7assErV/wq6FmVISIhat27ttq8XHR3NL3AluC5V49pUjutSNa5N5dxxXaprSZaiMw8AANUgKAEAqAZB6aKIiAhNmDBBERERpkvxKVyXqnFtKsd1qRrXpnLevi5B15kHAIDaoEUJAEA1CEoAAKpBUAIAUA2CEgCAahCU1Zg5c6batm2ryMhIJSQkaOPGjdUev2jRIrVv316RkZG66KKLtGLFCi9V6l21uS5z5sxRr1691KRJEzVp0kRJSUmnvY7+rLa/M6UWLFggh8OhQYMGebZAQ2p7XY4eParRo0erRYsWioiI0Hnnnce/p1/MmDFD559/vqKiohQfH697771XJ0+e9FK13vHhhx9q4MCBatmypRwOh955553TnpORkaFLL71UEREROvfcc/XSSy+5ryALlVqwYIEVHh5uzZs3z/r666+tUaNGWY0bN7ays7MrPf7jjz+2QkNDraeeesratm2b9cgjj1hhYWHW1q1bvVy5Z9X2ugwePNiaOXOmtWXLFmv79u3WX/7yFysmJsb64YcfvFy559X22pTavXu31apVK6tXr17Wtdde651ivai216WgoMDq2rWr1b9/f2v9+vXW7t27rYyMDCszM9PLlXteba/N66+/bkVERFivv/66tXv3bmvVqlVWixYtrHvvvdfLlXvWihUrrIcffthavHixJclasmRJtcfv2rXLql+/vpWammpt27bN+uc//2mFhoZaK1eudEs9BGUVunXrZo0ePbpsu6SkxGrZsqWVlpZW6fE33HCDNWDAgAr7EhISrNtuu82jdXpbba/L/youLrYaNWpkvfzyy54q0RhXrk1xcbHVvXt364UXXrCGDx8ekEFZ2+sya9Ys6+yzz7YKCwu9VaIxtb02o0ePtvr27VthX2pqqtWjRw+P1mlSTYLyr3/9q3XhhRdW2JeSkmIlJye7pQZuvVaisLBQmzZtUlJSUtm+kJAQJSUlacOGDZWes2HDhgrHS1JycnKVx/sjV67L/zp+/LiKiorUtGlTT5VphKvX5vHHH1dsbKxGjhzpjTK9zpXrsnTpUiUmJmr06NGKi4tTx44dNXnyZJWUlHirbK9w5dp0795dmzZtKrs9u2vXLq1YsUL9+/f3Ss2+ytPvv0E3KXpNHD58WCUlJYqLi6uwPy4uTt98802l52RlZVV6fFZWlsfq9DZXrsv/evDBB9WyZctTfqn9nSvXZv369Zo7d64yMzO9UKEZrlyXXbt2ad26dRoyZIhWrFihnTt36s4771RRUZEmTJjgjbK9wpVrM3jwYB0+fFg9e/aUZVkqLi7W7bffroceesgbJfusqt5/c3NzdeLECUVFRdXp69OihNdMmTJFCxYs0JIlSxQZGWm6HKPy8vI0dOhQzZkzR82aNTNdjk9xOp2KjY3V888/ry5duiglJUUPP/ywZs+ebbo04zIyMjR58mT961//0ubNm7V48WItX75ckyZNMl1aQKNFWYlmzZopNDRU2dnZFfZnZ2erefPmlZ7TvHnzWh3vj1y5LqWmTp2qKVOmaO3atbr44os9WaYRtb023333nfbs2aOBAweW7XM6nZKkevXqaceOHTrnnHM8W7QXuPI706JFC4WFhSk0NLRs3wUXXKCsrCwVFhYqPDzcozV7iyvX5tFHH9XQoUN1yy23SJIuuugi5efn69Zbb9XDDz9c7ZqKgayq99/o6Og6tyYlWpSVCg8PV5cuXZSenl62z+l0Kj09XYmJiZWek5iYWOF4SVqzZk2Vx/sjV66LJD311FOaNGmSVq5cqa5du3qjVK+r7bVp3769tm7dqszMzLLXNddcoz59+igzM1Px8fHeLN9jXPmd6dGjh3bu3Fn2h4Mkffvtt2rRokXAhKTk2rU5fvz4KWFY+geFFcTTdnv8/dctXYIC0IIFC6yIiAjrpZdesrZt22bdeuutVuPGja2srCzLsixr6NCh1tixY8uO//jjj6169epZU6dOtbZv325NmDAhYIeH1Oa6TJkyxQoPD7feeust68CBA2WvvLw8Uz+Cx9T22vyvQO31WtvrsnfvXqtRo0bWXXfdZe3YscNatmyZFRsbaz3xxBOmfgSPqe21mTBhgtWoUSPrjTfesHbt2mWtXr3aOuecc6wbbrjB1I/gEXl5edaWLVusLVu2WJKs6dOnW1u2bLG+//57y7Isa+zYsdbQoUPLji8dHvLAAw9Y27dvt2bOnMnwEG/55z//aZ111llWeHi41a1bN+vTTz8t+1zv3r2t4cOHVzj+zTfftM477zwrPDzcuvDCC63ly5d7uWLvqM11adOmjSXplNeECRO8X7gX1PZ35tcCNSgtq/bX5ZNPPrESEhKsiIgI6+yzz7aefPJJq7i42MtVe0dtrk1RUZH12GOPWeecc44VGRlpxcfHW3feeaf1888/e79wD/rggw8qfd8ovRbDhw+3evfufco5nTt3tsLDw62zzz7bevHFF91WD8tsAQBQDZ5RAgBQDYISAIBqEJQAAFSDoAQAoBoEJQAA1SAoAQCoBkEJAEA1CEoAAKpBUAIB4i9/+YscDscpr507d1b4XHh4uM4991w9/vjjKi4ulmSvSvHrc84880z1799fW7duNfxTAeYRlEAAufLKK3XgwIEKr3bt2lX43H//+1/dd999euyxx/T3v/+9wvk7duzQgQMHtGrVKhUUFGjAgAEqLCw08aMAPoOgBAJIRESEmjdvXuFVurpE6efatGmjO+64Q0lJSVq6dGmF82NjY9W8eXNdeumluueee7Rv374aL8oNBCqCEghSUVFRVbYWc3JytGDBAkkKqKWtAFewcDMQQJYtW6aGDRuWbV911VVatGhRhWMsy1J6erpWrVqlu+++u8LnWrduLUnKz8+XJF1zzTVq3769h6sGfBtBCQSQPn36aNasWWXbDRo0KPu4NESLiorkdDo1ePBgPfbYYxXO/+ijj1S/fn19+umnmjx5smbPnu2t0gGfRVACAaRBgwY699xzK/1caYiGh4erZcuWqlfv1H/+7dq1U+PGjXX++efr4MGDSklJ0YcffujpsgGfxjNKIEiUhuhZZ51VaUj+r9GjR+s///mPlixZ4oXqAN9FUAKoVP369TVq1ChNmDBBrO+OYEZQAqjSXXfdpe3bt5/SIQgIJg6LPxUBAKgSLUoAAKpBUAIAUA2CEgCAahCUAABUg6AEAKAaBCUAANUgKAEAqAZBCQBANQhKAACqQVACAFANghIAgGoQlAAAVOP/AWFdnkK3L++YAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 500x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.figure(figsize=(5,5))\n",
    "plt.plot([0,1],[0,1],'r--')\n",
    "plt.plot(lr_model.summary.roc.select(\"FPR\").collect(),lr_model.summary.roc.select(\"TPR\").collect())\n",
    "plt.xlabel(\"FPR\")\n",
    "plt.ylabel(\"TPR\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "21f4adb2-a2c2-4a83-9520-566a5421bde5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{Param(parent='LogisticRegression_d565ffd21c91', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.'): 0.0},\n",
       " {Param(parent='LogisticRegression_d565ffd21c91', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.'): 1.0}]"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.tuning import ParamGridBuilder\n",
    "\n",
    "grid_search = (\n",
    "    ParamGridBuilder()\n",
    "    .addGrid(lr.elasticNetParam,[0.0,1.0])\n",
    "    .build()\n",
    ")\n",
    "grid_search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "6511e7f5-06dc-4514-8957-6d2f8b86cfbb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.9896154286859112, 0.9896155418021585]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from pyspark.ml.tuning import CrossValidator\n",
    "\n",
    "cv = CrossValidator(estimator=food_pipeline,\n",
    "                    estimatorParamMaps = grid_search,\n",
    "                    evaluator=evaluator,\n",
    "                    numFolds=3,seed=13,\n",
    "                    collectSubModels=True)\n",
    "\n",
    "cv_model = cv.fit(train)\n",
    "cv_model.avgMetrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "8cb89e3d-bb88-4116-873b-19cb5dcfb107",
   "metadata": {},
   "outputs": [],
   "source": [
    "pipeline_food_model = cv_model.bestModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "03d8ea5c-348c-41ad-891b-985bf1af16f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "feature_names = [\"Intercept\"] + [x[\"name\"] for x in food_features.schema[\"features\"].metadata[\"ml_attr\"][\"attrs\"][\"numeric\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "7b029cd3-ff9a-4787-80e3-d8b0ff34defe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                         coef  abs_coef\n",
      "Intercept           -2.093416  2.093416\n",
      "seafood             -3.242796  3.242796\n",
      "avocado             -0.823192  0.823192\n",
      "snapper             -3.974354  3.974354\n",
      "broil               -0.871285  0.871285\n",
      "...                       ...       ...\n",
      "continuous_scaled_2 -0.327230  0.327230\n",
      "continuous_scaled_3 -1.384267  1.384267\n",
      "continuous_scaled_4 -2.429981  2.429981\n",
      "protein_ratio       -8.803789  8.803789\n",
      "fat_ratio            0.470149  0.470149\n",
      "\n",
      "[514 rows x 2 columns]\n"
     ]
    }
   ],
   "source": [
    "feature_coefficients = [lr_model.intercept] + list(lr_model.coefficients.values)\n",
    "coefficients = pd.DataFrame(feature_coefficients,index=feature_names,columns=[\"coef\"])\n",
    "coefficients[\"abs_coef\"] = coefficients[\"coef\"].abs()\n",
    "print(coefficients)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "8760a4d2-5351-42e5-9621-a34c6b4a6d93",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                     coef   abs_coef\n",
      "sangria        -22.831494  22.831494\n",
      "cauliflower    -16.423210  16.423210\n",
      "horseradish    -14.847373  14.847373\n",
      "arugula        -13.897515  13.897515\n",
      "mustard_greens -13.507600  13.507600\n",
      "...                   ...        ...\n",
      "radicchio       -0.017067   0.017067\n",
      "rutabaga         0.009016   0.009016\n",
      "self            -0.005587   0.005587\n",
      "quince          -0.003818   0.003818\n",
      "bon_apptit      -0.000964   0.000964\n",
      "\n",
      "[514 rows x 2 columns]\n"
     ]
    }
   ],
   "source": [
    "print(coefficients.sort_values([\"abs_coef\"],ascending=False))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "c25dd0a5-b145-48aa-af45-b521e5bfa437",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Index(['lunar_new_year', 'snack_week', 'barley', 'ramekin', 'new_mexico',\n",
      "       'pittsburgh', 'shavuot', 'new_york', 'whiskey', 'diwali',\n",
      "       ...\n",
      "       'chicken', 'washington_dc', 'low_sugar', 'coffee_grinder',\n",
      "       'leafy_green', 'mustard_greens', 'arugula', 'horseradish',\n",
      "       'cauliflower', 'sangria'],\n",
      "      dtype='object', length=514)\n"
     ]
    }
   ],
   "source": [
    "print(coefficients.sort_values([\"coef\"],ascending=False).index)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "860ed176-b605-4708-a2c0-dde13bfbbc7e",
   "metadata": {},
   "source": [
    "https://www.zhihu.com/question/36853661/answer/2780428963\n",
    "\n",
    "**概念**\n",
    "\n",
    "PDF：概率密度函数（probability density function）, 在数学中，连续型随机变量的概率密度函数（在不至于混淆时可以简称为密度函数）是一个描述这个随机变量的输出值，在某个确定的取值点附近的可能性的函数。\n",
    "\n",
    "PMF：概率质量函数（probability mass function), 在概率论中，概率质量函数是离散型随机变量在各特定取值上的概率。\n",
    "\n",
    "CDF：累积分布函数 (cumulative distribution function) ，又叫分布函数，是概率密度函数的积分，能完整描述一个实随机变量X的概率分布\n",
    "\n",
    "CCDF：1-CDF\n",
    "\n",
    "**区别**\n",
    "\n",
    "PDF是连续变量特有的，PMF是离散随机变量特有的。\n",
    "\n",
    "PDF的取值本身不是概率，它是一种趋势（密度）只有对连续随机变量的取值进行积分后才是概率，也就是说对于连续值确定它在某一点的概率是没有意义的。\n",
    "\n",
    "PMF的取值本身代表该值的概率。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02d05bf9-68c4-4581-be9e-b1badc7f86b4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
